import torch.nn as nn
from models.utils.activations import *

def get_activation_function(activation_type, inplace_flag=True):
    """Endlessly yield freshly constructed activation modules of one type.

    This is a generator: its body (including the banner printed below) only
    starts running on the first ``next()`` call. Every subsequent ``next()``
    builds and returns a brand-new module instance.

    Args:
        activation_type: one of 'relu', 'sigmoid', 'tanh', 'softsign',
            'Hardtanh' (capitalised), 'relu6', 'elliot', 'lecun'.
        inplace_flag: forwarded as ``inplace=`` to ReLU and ReLU6 only.

    Yields:
        A new activation module, or ``None`` forever when ``activation_type``
        is not one of the recognised names.
    """
    banner = '#' * 60
    print(banner)
    print('Model Activation Function : ' + activation_type)
    print(banner)

    # Zero-argument constructors; lambdas defer building (and the lookup of
    # the project-level Elliot / LeCun_tanh names) until a module is requested.
    factories = {
        'relu': lambda: nn.ReLU(inplace=inplace_flag),
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'softsign': nn.Softsign,
        'Hardtanh': nn.Hardtanh,
        'relu6': lambda: nn.ReLU6(inplace=inplace_flag),
        'elliot': lambda: Elliot(),
        'lecun': lambda: LeCun_tanh(),
    }
    build = factories.get(activation_type)

    while True:
        yield None if build is None else build()

def bias_check(oper_order, ind):
    """
    Decide whether the weighted-sum operation at position ``ind`` needs a bias.

    The operation that follows ``ind`` in ``oper_order`` is inspected; when
    ``ind`` is the last position the check wraps around to ``oper_order[0]``.
    If that follower is one of 'b', 'N', 'S', 'M' or 'D', return False (a
    following normalization makes a bias redundant); otherwise return True.
    """
    no_bias_followers = ('b', 'N', 'S', 'M', 'D')
    follower_pos = 0 if ind == len(oper_order) - 1 else ind + 1
    return oper_order[follower_pos] not in no_bias_followers

def make_conv_block(channels_in=None, channels_out=None, activation_generator=None, oper_order='cba',
                    kernel_size=3, stride=1, padding=1, bn_momentum=0.1):
    '''
    Build an ``nn.Sequential`` of 2d layers, one layer per character of ``oper_order``.

    Supported characters in ``oper_order``:

    'c' denotes 2d convolution (kernel_size / stride / padding, no bias),
    'd' denotes depthwise 2d convolution (groups = current channel count, no bias),
    'p' denotes pointwise (1x1) 2d convolution (no bias; stride/padding not applied),
    'b' denotes 2d BatchNorm with momentum ``bn_momentum``,
    'D' denotes 2d dropout with p=0.5,
    'l' denotes a ``Logging`` layer over the current channel count,
    'a' denotes an activation pulled from ``activation_generator``.

    The BatchNorm variants 'N', 'G', 'B', 'S', 'M', 'T' (normalization-only,
    gamma-only, etc.) are NOT implemented by this function and are rejected.

    Args:
        channels_in: channel count fed to the first channel-dependent layer.
        channels_out: output channel count of every convolution; after a
            convolution, later layers are sized with this value.
        activation_generator: iterator consumed once per 'a' (e.g. the
            generator returned by ``get_activation_function``).
        oper_order: string (or sequence of single characters) of operations.
        kernel_size, stride, padding: used by 'c' and 'd' convolutions.
        bn_momentum: momentum for 'b' BatchNorm layers.

    Returns:
        nn.Sequential containing the layers in ``oper_order`` order.

    Raises:
        ValueError: if ``oper_order`` contains an unsupported character.
    '''
    layers = []
    channel = channels_in
    for operation_type in oper_order:
        if operation_type == 'c':
            # Bias is always disabled here; bias_check() is deliberately not consulted.
            layer = nn.Conv2d(channel, channels_out, kernel_size=kernel_size, stride=stride,
                              padding=padding, bias=False)
            channel = channels_out
        elif operation_type == 'd':
            # groups == in_channels: each filter group sees a single input channel.
            layer = nn.Conv2d(channel, channels_out, kernel_size=kernel_size, stride=stride,
                              groups=channel, padding=padding, bias=False)
            channel = channels_out
        elif operation_type == 'p':
            layer = nn.Conv2d(channel, channels_out, 1, bias=False)
            channel = channels_out
        elif operation_type == 'b':
            layer = nn.BatchNorm2d(channel, momentum=bn_momentum)
        elif operation_type == 'D':
            layer = nn.Dropout2d(p=0.5)
        elif operation_type == 'l':
            layer = Logging(channel)
        elif operation_type == 'a':
            layer = next(activation_generator)
        else:
            # Previously an unknown character either silently pulled an activation
            # (whenever 'a' appeared anywhere in oper_order) or re-appended the
            # previous layer object; fail loudly instead.
            raise ValueError(
                f"Unsupported operation {operation_type!r} in oper_order {oper_order!r}")

        layers.append(layer)

    return nn.Sequential(*layers)


def make_fc_block(activations_in=None, activations_out=None, 
                activation_generator=None, oper_order='lba', bn_momentum=0.1):
    '''
    Build an ``nn.Sequential`` of 1d layers, one layer per character of ``oper_order``.

    Supported characters in ``oper_order``:

    'f' denotes a fully-connected (``nn.Linear``, with bias) operation,
    'c' is accepted as an alias of 'f',
    'b' denotes 1d BatchNorm with momentum ``bn_momentum``,
    'D' denotes dropout with p=0.5,
    'l' denotes a ``Logging`` layer over the current feature count,
    'a' denotes an activation pulled from ``activation_generator``.

    The BatchNorm variants 'N', 'G', 'B', 'S', 'M', 'T' (normalization-only,
    gamma-only, etc.) are NOT implemented by this function and are rejected.

    NOTE(review): the default ``oper_order='lba'`` contains no fully-connected
    layer ('l' is Logging here), so it builds Logging -> BatchNorm1d ->
    activation sized on ``activations_in``. Confirm this default is intended.

    Args:
        activations_in: feature count fed to the first size-dependent layer.
        activations_out: output size of every Linear layer; after a Linear,
            later layers are sized with this value.
        activation_generator: iterator consumed once per 'a'.
        oper_order: string (or sequence of single characters) of operations.
        bn_momentum: momentum for 'b' BatchNorm layers.

    Returns:
        nn.Sequential containing the layers in ``oper_order`` order.

    Raises:
        ValueError: if ``oper_order`` contains an unsupported character.
    '''
    layers = []
    activations = activations_in

    for operation_type in oper_order:
        if operation_type in ('f', 'c'):
            # Bias is always enabled here; bias_check() is deliberately not consulted.
            layer = nn.Linear(activations, activations_out, bias=True)
            activations = activations_out
        elif operation_type == 'b':
            layer = nn.BatchNorm1d(activations, momentum=bn_momentum)
        elif operation_type == 'D':
            layer = nn.Dropout(p=0.5)
        elif operation_type == 'l':
            layer = Logging(activations)
        elif operation_type == 'a':
            layer = next(activation_generator)
        else:
            # Previously an unknown character re-appended the previous layer
            # object (or hit an unbound local); fail loudly instead.
            raise ValueError(
                f"Unsupported operation {operation_type!r} in oper_order {oper_order!r}")

        layers.append(layer)

    return nn.Sequential(*layers)
